{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadDataSet(fileName):\n",
    "    '''Load a tab-separated numeric text file into a NumPy matrix.\n",
    "\n",
    "    Args:\n",
    "        fileName: path of the file to read; every non-blank line holds\n",
    "            tab-separated fields that float() can parse\n",
    "    Return:\n",
    "        np.matrix with one row per non-blank line of the file\n",
    "    Raises:\n",
    "        ValueError: if a field cannot be converted to float\n",
    "    '''\n",
    "    dataArr = []\n",
    "    with open(fileName) as fr:\n",
    "        # Iterate the file lazily instead of materialising it with readlines().\n",
    "        for line in fr:\n",
    "            stripped = line.strip()\n",
    "            # Skip blank lines (e.g. a trailing newline): float('') would raise.\n",
    "            if not stripped:\n",
    "                continue\n",
    "            dataArr.append([float(field) for field in stripped.split('\\t')])\n",
    "    # np.mat was removed in NumPy 2.0; np.asmatrix builds the same np.matrix.\n",
    "    return np.asmatrix(dataArr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binSplitDataSet(dataSet, feature, value):\n",
    "    '''Partition the rows of a data set by thresholding one feature column.\n",
    "\n",
    "    Args:\n",
    "        dataSet: 2-D data set, one sample per row (the notebook passes np.matrix)\n",
    "        feature: index of the column to split on\n",
    "        value: threshold the column is compared against\n",
    "    Return:\n",
    "        mat0, mat1: rows whose feature is > value, and rows whose feature is <= value\n",
    "    '''\n",
    "    column = dataSet[:, feature]\n",
    "    # Flatten each comparison into a 1-D boolean mask that selects whole rows.\n",
    "    aboveMask = np.asarray(column > value).ravel()\n",
    "    belowMask = np.asarray(column <= value).ravel()\n",
    "    return dataSet[aboveMask, :], dataSet[belowMask, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[1., 0., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# eye must be qualified: only numpy-as-np is imported, so a bare eye() is a NameError\n",
    "# on a fresh kernel. np.asmatrix replaces np.mat, which was removed in NumPy 2.0.\n",
    "testMat = np.asmatrix(np.eye(4))\n",
    "testMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testMat[np.nonzero(testMat[:, 1] < 0.5)[0],:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[0., 1., 0., 0.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mat0,mat1 = binSplitDataSet(testMat, 1, 0.5)\n",
    "mat0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mat1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[ 3.609800e-02,  1.550960e-01],\n",
       "        [ 9.933490e-01,  1.077553e+00],\n",
       "        [ 5.308970e-01,  8.934620e-01],\n",
       "        [ 7.123860e-01,  5.648580e-01],\n",
       "        [ 3.435540e-01, -3.717000e-01],\n",
       "        [ 9.801600e-02, -3.327600e-01],\n",
       "        [ 6.911150e-01,  8.343910e-01],\n",
       "        [ 9.135800e-02,  9.993500e-02],\n",
       "        [ 7.270980e-01,  1.000567e+00],\n",
       "        [ 9.519490e-01,  9.452550e-01],\n",
       "        [ 7.685960e-01,  7.602190e-01],\n",
       "        [ 5.413140e-01,  8.937480e-01],\n",
       "        [ 1.463660e-01,  3.428300e-02],\n",
       "        [ 6.731950e-01,  9.150770e-01],\n",
       "        [ 1.835100e-01,  1.848430e-01],\n",
       "        [ 3.395630e-01,  2.067830e-01],\n",
       "        [ 5.179210e-01,  1.493586e+00],\n",
       "        [ 7.037550e-01,  1.101678e+00],\n",
       "        [ 8.307000e-03,  6.997600e-02],\n",
       "        [ 2.439090e-01, -2.946700e-02],\n",
       "        [ 3.069640e-01, -1.773210e-01],\n",
       "        [ 3.649200e-02,  4.081550e-01],\n",
       "        [ 2.955110e-01,  2.882000e-03],\n",
       "        [ 8.375220e-01,  1.229373e+00],\n",
       "        [ 2.020540e-01, -8.774400e-02],\n",
       "        [ 9.193840e-01,  1.029889e+00],\n",
       "        [ 3.772010e-01, -2.435500e-01],\n",
       "        [ 8.148250e-01,  1.095206e+00],\n",
       "        [ 6.112700e-01,  9.820360e-01],\n",
       "        [ 7.224300e-02, -4.209830e-01],\n",
       "        [ 4.102300e-01,  3.317220e-01],\n",
       "        [ 8.690770e-01,  1.114825e+00],\n",
       "        [ 6.205990e-01,  1.334421e+00],\n",
       "        [ 1.011490e-01,  6.883400e-02],\n",
       "        [ 8.208020e-01,  1.325907e+00],\n",
       "        [ 5.200440e-01,  9.619830e-01],\n",
       "        [ 4.881300e-01, -9.779100e-02],\n",
       "        [ 8.198230e-01,  8.352640e-01],\n",
       "        [ 9.750220e-01,  6.735790e-01],\n",
       "        [ 9.531120e-01,  1.064690e+00],\n",
       "        [ 4.759760e-01, -1.637070e-01],\n",
       "        [ 2.731470e-01, -4.552190e-01],\n",
       "        [ 8.045860e-01,  9.240330e-01],\n",
       "        [ 7.479500e-02, -3.496920e-01],\n",
       "        [ 6.253360e-01,  6.236960e-01],\n",
       "        [ 6.562180e-01,  9.585060e-01],\n",
       "        [ 8.340780e-01,  1.010580e+00],\n",
       "        [ 7.819300e-01,  1.074488e+00],\n",
       "        [ 9.849000e-03,  5.659400e-02],\n",
       "        [ 3.022170e-01, -1.486500e-01],\n",
       "        [ 6.782870e-01,  9.077270e-01],\n",
       "        [ 1.805060e-01,  1.036760e-01],\n",
       "        [ 1.936410e-01, -3.275890e-01],\n",
       "        [ 3.434790e-01,  1.752640e-01],\n",
       "        [ 1.458090e-01,  1.369790e-01],\n",
       "        [ 9.967570e-01,  1.035533e+00],\n",
       "        [ 5.902100e-01,  1.336661e+00],\n",
       "        [ 2.380700e-01, -3.584590e-01],\n",
       "        [ 5.613620e-01,  1.070529e+00],\n",
       "        [ 3.775970e-01,  8.850500e-02],\n",
       "        [ 9.914200e-02,  2.528000e-02],\n",
       "        [ 5.395580e-01,  1.053846e+00],\n",
       "        [ 7.902400e-01,  5.332140e-01],\n",
       "        [ 2.422040e-01,  2.093590e-01],\n",
       "        [ 1.523240e-01,  1.328580e-01],\n",
       "        [ 2.526490e-01, -5.561300e-02],\n",
       "        [ 8.959300e-01,  1.077275e+00],\n",
       "        [ 1.333000e-01, -2.231430e-01],\n",
       "        [ 5.597630e-01,  1.253151e+00],\n",
       "        [ 6.436650e-01,  1.024241e+00],\n",
       "        [ 8.772410e-01,  7.970050e-01],\n",
       "        [ 6.137650e-01,  1.621091e+00],\n",
       "        [ 6.457620e-01,  1.026886e+00],\n",
       "        [ 6.513760e-01,  1.315384e+00],\n",
       "        [ 6.977180e-01,  1.212434e+00],\n",
       "        [ 7.425270e-01,  1.087056e+00],\n",
       "        [ 9.010560e-01,  1.055900e+00],\n",
       "        [ 3.623140e-01, -5.564640e-01],\n",
       "        [ 9.482680e-01,  6.318620e-01],\n",
       "        [ 2.340000e-04,  6.090300e-02],\n",
       "        [ 7.500780e-01,  9.062910e-01],\n",
       "        [ 3.254120e-01, -2.192450e-01],\n",
       "        [ 7.268280e-01,  1.017112e+00],\n",
       "        [ 3.480130e-01,  4.893900e-02],\n",
       "        [ 4.581210e-01, -6.145600e-02],\n",
       "        [ 2.807380e-01, -2.288800e-01],\n",
       "        [ 5.677040e-01,  9.690580e-01],\n",
       "        [ 7.509180e-01,  7.481040e-01],\n",
       "        [ 5.758050e-01,  8.990900e-01],\n",
       "        [ 5.079400e-01,  1.107265e+00],\n",
       "        [ 7.176900e-02, -1.109460e-01],\n",
       "        [ 5.535200e-01,  1.391273e+00],\n",
       "        [ 4.011520e-01, -1.216400e-01],\n",
       "        [ 4.066490e-01, -3.663170e-01],\n",
       "        [ 6.521210e-01,  1.004346e+00],\n",
       "        [ 3.478370e-01, -1.534050e-01],\n",
       "        [ 8.193100e-02, -2.697560e-01],\n",
       "        [ 8.216480e-01,  1.280895e+00],\n",
       "        [ 4.801400e-02,  6.449600e-02],\n",
       "        [ 1.309620e-01,  1.842410e-01],\n",
       "        [ 7.734220e-01,  1.125943e+00],\n",
       "        [ 7.896250e-01,  5.526140e-01],\n",
       "        [ 9.699400e-02,  2.271670e-01],\n",
       "        [ 6.257910e-01,  1.244731e+00],\n",
       "        [ 5.895750e-01,  1.185812e+00],\n",
       "        [ 3.231810e-01,  1.808110e-01],\n",
       "        [ 8.224430e-01,  1.086648e+00],\n",
       "        [ 3.603230e-01, -2.048300e-01],\n",
       "        [ 9.501530e-01,  1.022906e+00],\n",
       "        [ 5.275050e-01,  8.795600e-01],\n",
       "        [ 8.600490e-01,  7.174900e-01],\n",
       "        [ 7.044000e-03,  9.415000e-02],\n",
       "        [ 4.383670e-01,  3.401400e-02],\n",
       "        [ 5.745730e-01,  1.066130e+00],\n",
       "        [ 5.366890e-01,  8.672840e-01],\n",
       "        [ 7.821670e-01,  8.860490e-01],\n",
       "        [ 9.898880e-01,  7.442070e-01],\n",
       "        [ 7.614740e-01,  1.058262e+00],\n",
       "        [ 9.854250e-01,  1.227946e+00],\n",
       "        [ 1.325430e-01, -3.293720e-01],\n",
       "        [ 3.469860e-01, -1.503890e-01],\n",
       "        [ 7.687840e-01,  8.997050e-01],\n",
       "        [ 8.489210e-01,  1.170959e+00],\n",
       "        [ 4.492800e-01,  6.909800e-02],\n",
       "        [ 6.617200e-02,  5.243900e-02],\n",
       "        [ 8.137190e-01,  7.066010e-01],\n",
       "        [ 6.619230e-01,  7.670400e-01],\n",
       "        [ 5.294910e-01,  1.022206e+00],\n",
       "        [ 8.464550e-01,  7.200300e-01],\n",
       "        [ 4.486560e-01,  2.697400e-02],\n",
       "        [ 7.950720e-01,  9.657210e-01],\n",
       "        [ 1.181560e-01, -7.740900e-02],\n",
       "        [ 8.424800e-02, -1.954700e-02],\n",
       "        [ 8.458150e-01,  9.526170e-01],\n",
       "        [ 5.769460e-01,  1.234129e+00],\n",
       "        [ 7.720830e-01,  1.299018e+00],\n",
       "        [ 6.966480e-01,  8.454230e-01],\n",
       "        [ 5.950120e-01,  1.213435e+00],\n",
       "        [ 6.486750e-01,  1.287407e+00],\n",
       "        [ 8.970940e-01,  1.240209e+00],\n",
       "        [ 5.529900e-01,  1.036158e+00],\n",
       "        [ 3.329820e-01,  2.100840e-01],\n",
       "        [ 6.561500e-02, -3.069700e-01],\n",
       "        [ 2.786610e-01,  2.536280e-01],\n",
       "        [ 7.731680e-01,  1.140917e+00],\n",
       "        [ 2.036930e-01, -6.403600e-02],\n",
       "        [ 3.556880e-01, -1.193990e-01],\n",
       "        [ 9.888520e-01,  1.069062e+00],\n",
       "        [ 5.187350e-01,  1.037179e+00],\n",
       "        [ 5.145630e-01,  1.156648e+00],\n",
       "        [ 9.764140e-01,  8.629110e-01],\n",
       "        [ 9.190740e-01,  1.123413e+00],\n",
       "        [ 6.977770e-01,  8.278050e-01],\n",
       "        [ 9.280970e-01,  8.832250e-01],\n",
       "        [ 9.002720e-01,  9.968710e-01],\n",
       "        [ 3.441020e-01, -6.153900e-02],\n",
       "        [ 1.480490e-01,  2.042980e-01],\n",
       "        [ 1.300520e-01, -2.616700e-02],\n",
       "        [ 3.020010e-01,  3.171350e-01],\n",
       "        [ 3.371000e-01,  2.633200e-02],\n",
       "        [ 3.149240e-01, -1.952000e-03],\n",
       "        [ 2.696810e-01, -1.659710e-01],\n",
       "        [ 1.960050e-01, -4.884700e-02],\n",
       "        [ 1.290610e-01,  3.051070e-01],\n",
       "        [ 9.367830e-01,  1.026258e+00],\n",
       "        [ 3.055400e-01, -1.159910e-01],\n",
       "        [ 6.839210e-01,  1.414382e+00],\n",
       "        [ 6.223980e-01,  7.663300e-01],\n",
       "        [ 9.025320e-01,  8.616010e-01],\n",
       "        [ 7.125030e-01,  9.334900e-01],\n",
       "        [ 5.900620e-01,  7.055310e-01],\n",
       "        [ 7.231200e-01,  1.307248e+00],\n",
       "        [ 1.882180e-01,  1.136850e-01],\n",
       "        [ 6.436010e-01,  7.825520e-01],\n",
       "        [ 5.202070e-01,  1.209557e+00],\n",
       "        [ 2.331150e-01, -3.481470e-01],\n",
       "        [ 4.656250e-01, -1.529400e-01],\n",
       "        [ 8.845120e-01,  1.117833e+00],\n",
       "        [ 6.632000e-01,  7.016340e-01],\n",
       "        [ 2.688570e-01,  7.344700e-02],\n",
       "        [ 7.292340e-01,  9.319560e-01],\n",
       "        [ 4.296640e-01, -1.886590e-01],\n",
       "        [ 7.371890e-01,  1.200781e+00],\n",
       "        [ 3.785950e-01, -2.960940e-01],\n",
       "        [ 9.301730e-01,  1.035645e+00],\n",
       "        [ 7.743010e-01,  8.367630e-01],\n",
       "        [ 2.739400e-01, -8.571300e-02],\n",
       "        [ 8.244420e-01,  1.082153e+00],\n",
       "        [ 6.260110e-01,  8.405440e-01],\n",
       "        [ 6.793900e-01,  1.307217e+00],\n",
       "        [ 5.782520e-01,  9.218850e-01],\n",
       "        [ 7.855410e-01,  1.165296e+00],\n",
       "        [ 5.974090e-01,  9.747700e-01],\n",
       "        [ 1.408300e-02, -1.325250e-01],\n",
       "        [ 6.638700e-01,  1.187129e+00],\n",
       "        [ 5.523810e-01,  1.369630e+00],\n",
       "        [ 6.838860e-01,  9.999850e-01],\n",
       "        [ 2.103340e-01, -6.899000e-03],\n",
       "        [ 6.045290e-01,  1.212685e+00],\n",
       "        [ 2.507440e-01,  4.629700e-02]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myMat = loadDataSet('ex00.txt')\n",
    "myMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200, 2)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.shape(myMat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def regLeaf(dataSet):\n",
    "    '''Build the value stored in a regression-tree leaf.\n",
    "\n",
    "    Args:\n",
    "        dataSet: data set whose last column is the target variable\n",
    "    Return:\n",
    "        mean of the target column\n",
    "    '''\n",
    "    targets = dataSet[:, -1]\n",
    "    return targets.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def regErr(dataSet):\n",
    "    '''Estimate the error of a data set for regression-tree splitting.\n",
    "\n",
    "    Args:\n",
    "        dataSet: data set whose last column is the target variable\n",
    "    Return:\n",
    "        total squared deviation of the target column from its mean\n",
    "    '''\n",
    "    targets = dataSet[:, -1]\n",
    "    numRows = dataSet.shape[0]\n",
    "    # var() is the mean squared deviation; scale by the row count to get the total.\n",
    "    return targets.var() * numRows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):\n",
    "    '''Recursively build a tree from a data set.\n",
    "\n",
    "    Args:\n",
    "        dataSet: data set handed to chooseBestSplit and binSplitDataSet\n",
    "        leafType: function that builds a leaf value from a data set\n",
    "            (defaults to regLeaf, i.e. a regression tree)\n",
    "        errType: function that estimates the error of a data set\n",
    "        ops: stopping parameters, forwarded unchanged to chooseBestSplit\n",
    "    Return:\n",
    "        the leaf value when chooseBestSplit reports no split; otherwise a dict\n",
    "        with keys 'spInd' (split feature), 'spVal' (split value), 'left' and\n",
    "        'right' (subtrees built from the two halves of the split)\n",
    "    '''\n",
    "    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)\n",
    "    # feat is None when this node must be a leaf. Test identity rather than\n",
    "    # equality so a valid feature index (including 0) is never misread.\n",
    "    if feat is None:\n",
    "        return val\n",
    "    lSet, rSet = binSplitDataSet(dataSet, feat, val)\n",
    "    return {\n",
    "        'spInd': feat,\n",
    "        'spVal': val,\n",
    "        'left': createTree(lSet, leafType, errType, ops),\n",
    "        'right': createTree(rSet, leafType, errType, ops),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):\n",
    "    '''Find the best binary split of a data set.\n",
    "\n",
    "    Args:\n",
    "        dataSet: data set (np.matrix) whose last column is the target variable\n",
    "        leafType: function that builds a leaf value from a data set\n",
    "        errType: function that estimates the error of a data set\n",
    "        ops: (tolS, tolN) - minimum error reduction a split must achieve, and\n",
    "            minimum number of rows allowed on each side of a split\n",
    "    Return:\n",
    "        (bestIndex, bestValue): feature column and threshold of the best split,\n",
    "        or (None, leafType(dataSet)) when the node should be a leaf\n",
    "    '''\n",
    "    minGain, minRows = ops[0], ops[1]\n",
    "    # Exit condition 1: every target value is identical, so nothing is left to split.\n",
    "    targets = dataSet[:, -1].T.tolist()[0]\n",
    "    if len(set(targets)) == 1:\n",
    "        return None, leafType(dataSet)\n",
    "    numCols = np.shape(dataSet)[1]\n",
    "    # Error of the unsplit data set, the baseline every candidate is compared with.\n",
    "    baseErr = errType(dataSet)\n",
    "    lowestErr, bestIndex, bestValue = np.inf, 0, 0\n",
    "    # Try every distinct value of every feature column as a threshold.\n",
    "    for col in range(numCols - 1):\n",
    "        candidates = set(dataSet[:, col].T.tolist()[0])\n",
    "        for threshold in candidates:\n",
    "            above, below = binSplitDataSet(dataSet, col, threshold)\n",
    "            # A side with fewer than minRows rows disqualifies this threshold.\n",
    "            if np.shape(above)[0] < minRows or np.shape(below)[0] < minRows:\n",
    "                continue\n",
    "            splitErr = errType(above) + errType(below)\n",
    "            if splitErr < lowestErr:\n",
    "                lowestErr, bestIndex, bestValue = splitErr, col, threshold\n",
    "    # Exit condition 2: the error reduction is below the threshold, make a leaf.\n",
    "    if (baseErr - lowestErr) < minGain:\n",
    "        return None, leafType(dataSet)\n",
    "    # Exit condition 3: the chosen split leaves too few rows on one side, make a leaf.\n",
    "    above, below = binSplitDataSet(dataSet, bestIndex, bestValue)\n",
    "    if np.shape(above)[0] < minRows or np.shape(below)[0] < minRows:\n",
    "        return None, leafType(dataSet)\n",
    "    return bestIndex, bestValue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'left': 1.0180967672413792,\n",
       " 'right': -0.04465028571428572,\n",
       " 'spInd': 0,\n",
       " 'spVal': 0.48813}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "createTree(myMat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'left': {'left': {'left': 3.9871632,\n",
       "   'right': 2.9836209534883724,\n",
       "   'spInd': 1,\n",
       "   'spVal': 0.797583},\n",
       "  'right': 1.980035071428571,\n",
       "  'spInd': 1,\n",
       "  'spVal': 0.582002},\n",
       " 'right': {'left': 1.0289583666666666,\n",
       "  'right': -0.023838155555555553,\n",
       "  'spInd': 1,\n",
       "  'spVal': 0.197834},\n",
       " 'spInd': 1,\n",
       " 'spVal': 0.39435}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myMat1 = loadDataSet('ex0.txt')\n",
    "createTree(myMat1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 回归树剪枝操作\n",
    "def isTree(obj):\n",
    "    '''Return True when obj is a subtree (a dict node) rather than a leaf value.\n",
    "\n",
    "    isinstance also accepts dict subclasses and no longer matches an unrelated\n",
    "    class that merely happens to be named dict.\n",
    "    '''\n",
    "    return isinstance(obj, dict)\n",
    "\n",
    "def getMean(tree):\n",
    "    '''Collapse a tree into a single number, recursing down to the leaves.\n",
    "\n",
    "    Each subtree is replaced in place by its own collapsed value, so tree is\n",
    "    modified. Returns the average of the resulting left and right values.\n",
    "    '''\n",
    "    for side in ('right', 'left'):\n",
    "        branch = tree[side]\n",
    "        if isTree(branch):\n",
    "            tree[side] = getMean(branch)\n",
    "    return (tree['left'] + tree['right']) / 2.0\n",
    "    \n",
    "def prune(tree, testData):\n",
    "    '''Post-prune a tree against held-out test data.\n",
    "\n",
    "    Args:\n",
    "        tree: dict node with keys 'spInd', 'spVal', 'left' and 'right';\n",
    "            its subtrees are replaced in place as they are pruned\n",
    "        testData: test samples reaching this node; the last column is the target\n",
    "    Return:\n",
    "        the node itself, or a single number when the node was collapsed\n",
    "    '''\n",
    "    # No test samples reach this node: collapse the whole subtree.\n",
    "    if np.shape(testData)[0] == 0:\n",
    "        return getMean(tree)\n",
    "    leftData, rightData = binSplitDataSet(testData, tree['spInd'], tree['spVal'])\n",
    "    # Prune the subtrees first, each with the test rows routed to its side.\n",
    "    if isTree(tree['left']):\n",
    "        tree['left'] = prune(tree['left'], leftData)\n",
    "    if isTree(tree['right']):\n",
    "        tree['right'] = prune(tree['right'], rightData)\n",
    "    # Merging is only considered once both children are leaves.\n",
    "    if isTree(tree['left']) or isTree(tree['right']):\n",
    "        return tree\n",
    "    leftSqErr = np.sum(np.power(leftData[:,-1] - tree['left'], 2))\n",
    "    rightSqErr = np.sum(np.power(rightData[:,-1] - tree['right'], 2))\n",
    "    errorNoMerge = leftSqErr + rightSqErr\n",
    "    treeMean = (tree['left'] + tree['right']) / 2.0\n",
    "    errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))\n",
    "    # Collapse the two leaves only when that lowers the squared test error.\n",
    "    if errorMerge < errorNoMerge:\n",
    "        print(\"merging\")\n",
    "        return treeMean\n",
    "    return tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[ 4.21862000e-01,  1.08302410e+01],\n",
       "        [ 1.05349000e-01, -2.24161100e+00],\n",
       "        [ 1.55196000e-01,  2.18729760e+01],\n",
       "        [ 1.61152000e-01,  2.01541800e+00],\n",
       "        [ 3.82632000e-01, -3.87789790e+01],\n",
       "        [ 1.77100000e-02,  2.01091130e+01],\n",
       "        [ 1.29656000e-01,  1.52668870e+01],\n",
       "        [ 6.13926000e-01,  1.11900063e+02],\n",
       "        [ 4.09277000e-01,  1.87473100e+00],\n",
       "        [ 8.07556000e-01,  1.11223754e+02],\n",
       "        [ 5.93722000e-01,  1.33835486e+02],\n",
       "        [ 9.53239000e-01,  1.10465070e+02],\n",
       "        [ 2.57402000e-01,  1.53328990e+01],\n",
       "        [ 6.45385000e-01,  9.39830540e+01],\n",
       "        [ 5.63460000e-01,  9.36452770e+01],\n",
       "        [ 4.08338000e-01, -3.07198780e+01],\n",
       "        [ 8.74394000e-01,  9.18735050e+01],\n",
       "        [ 2.63805000e-01, -1.92752000e-01],\n",
       "        [ 4.11198000e-01,  1.07511180e+01],\n",
       "        [ 4.49884000e-01,  9.21190100e+00],\n",
       "        [ 6.46315000e-01,  1.13533660e+02],\n",
       "        [ 6.73718000e-01,  1.25135638e+02],\n",
       "        [ 8.05148000e-01,  1.13300462e+02],\n",
       "        [ 7.59327000e-01,  7.26685720e+01],\n",
       "        [ 5.19172000e-01,  8.21316980e+01],\n",
       "        [ 7.41031000e-01,  1.06777146e+02],\n",
       "        [ 3.09370000e-02,  9.85912700e+00],\n",
       "        [ 2.68848000e-01, -3.41379550e+01],\n",
       "        [ 4.74901000e-01, -1.12013010e+01],\n",
       "        [ 5.88266000e-01,  1.20501998e+02],\n",
       "        [ 8.93936000e-01,  1.42826476e+02],\n",
       "        [ 8.70990000e-01,  1.05751746e+02],\n",
       "        [ 4.30763000e-01,  3.91462580e+01],\n",
       "        [ 5.76650000e-02,  1.53718970e+01],\n",
       "        [ 1.00076000e-01,  9.13176100e+00],\n",
       "        [ 9.80716000e-01,  1.16145896e+02],\n",
       "        [ 2.35289000e-01, -1.36912240e+01],\n",
       "        [ 2.28098000e-01,  1.60891510e+01],\n",
       "        [ 6.22248000e-01,  9.93455510e+01],\n",
       "        [ 4.01467000e-01, -1.69438300e+00],\n",
       "        [ 9.60334000e-01,  1.10795415e+02],\n",
       "        [ 3.12140000e-02, -5.33004200e+00],\n",
       "        [ 5.04228000e-01,  9.60035250e+01],\n",
       "        [ 7.79660000e-01,  7.59215820e+01],\n",
       "        [ 5.04496000e-01,  1.01341462e+02],\n",
       "        [ 8.50974000e-01,  9.62930640e+01],\n",
       "        [ 7.01119000e-01,  1.02333839e+02],\n",
       "        [ 1.91551000e-01,  5.07232600e+00],\n",
       "        [ 6.67116000e-01,  9.23100190e+01],\n",
       "        [ 5.55584000e-01,  8.03671290e+01],\n",
       "        [ 6.80006000e-01,  1.32965442e+02],\n",
       "        [ 3.93899000e-01,  3.86052830e+01],\n",
       "        [ 4.89400000e-02, -9.86187100e+00],\n",
       "        [ 9.63282000e-01,  1.15407485e+02],\n",
       "        [ 6.55496000e-01,  1.04269918e+02],\n",
       "        [ 5.76463000e-01,  1.41127267e+02],\n",
       "        [ 6.75708000e-01,  9.62279960e+01],\n",
       "        [ 8.53457000e-01,  1.14252288e+02],\n",
       "        [ 3.93300000e-03, -1.21828610e+01],\n",
       "        [ 5.49512000e-01,  9.79272240e+01],\n",
       "        [ 2.18967000e-01, -4.71246200e+00],\n",
       "        [ 6.59972000e-01,  1.20950439e+02],\n",
       "        [ 8.25600000e-03,  8.02681600e+00],\n",
       "        [ 9.95000000e-02, -1.43184340e+01],\n",
       "        [ 3.52215000e-01, -3.74754600e+00],\n",
       "        [ 8.74926000e-01,  8.92473560e+01],\n",
       "        [ 6.35084000e-01,  9.94960590e+01],\n",
       "        [ 3.96410000e-02,  1.41471090e+01],\n",
       "        [ 6.65111000e-01,  1.03298719e+02],\n",
       "        [ 1.56583000e-01, -2.54070300e+00],\n",
       "        [ 6.48843000e-01,  1.19333019e+02],\n",
       "        [ 8.93237000e-01,  9.52095850e+01],\n",
       "        [ 1.28807000e-01,  5.55847900e+00],\n",
       "        [ 1.37438000e-01,  5.56768500e+00],\n",
       "        [ 6.30538000e-01,  9.84627920e+01],\n",
       "        [ 2.96084000e-01, -4.17994380e+01],\n",
       "        [ 6.32099000e-01,  8.48950980e+01],\n",
       "        [ 9.87681000e-01,  1.06726447e+02],\n",
       "        [ 7.44909000e-01,  1.11279705e+02],\n",
       "        [ 8.62030000e-01,  1.04581156e+02],\n",
       "        [ 8.06490000e-02, -7.67998500e+00],\n",
       "        [ 8.31277000e-01,  5.90533560e+01],\n",
       "        [ 1.98716000e-01,  2.68788010e+01],\n",
       "        [ 8.60932000e-01,  9.06329300e+01],\n",
       "        [ 8.83250000e-01,  9.27595950e+01],\n",
       "        [ 8.18003000e-01,  1.10272219e+02],\n",
       "        [ 9.49216000e-01,  1.15200237e+02],\n",
       "        [ 4.60078000e-01, -3.59579810e+01],\n",
       "        [ 5.61077000e-01,  9.35457610e+01],\n",
       "        [ 8.63767000e-01,  1.14125786e+02],\n",
       "        [ 4.76891000e-01, -2.97740600e+01],\n",
       "        [ 5.37826000e-01,  8.15879220e+01],\n",
       "        [ 6.86224000e-01,  1.10911198e+02],\n",
       "        [ 9.82327000e-01,  1.19114523e+02],\n",
       "        [ 9.44453000e-01,  9.20334810e+01],\n",
       "        [ 7.82270000e-02,  3.02168730e+01],\n",
       "        [ 7.82937000e-01,  9.25886460e+01],\n",
       "        [ 4.65886000e-01,  2.22213900e+00],\n",
       "        [ 8.85024000e-01,  9.02478900e+01],\n",
       "        [ 1.86077000e-01,  7.14441500e+00],\n",
       "        [ 9.15828000e-01,  8.40100740e+01],\n",
       "        [ 7.96649000e-01,  1.15572156e+02],\n",
       "        [ 1.27821000e-01,  2.89336880e+01],\n",
       "        [ 4.33429000e-01,  6.78257500e+00],\n",
       "        [ 9.46796000e-01,  1.08574116e+02],\n",
       "        [ 3.86915000e-01, -1.74046010e+01],\n",
       "        [ 5.61192000e-01,  9.21427000e+01],\n",
       "        [ 1.82490000e-01,  1.07646160e+01],\n",
       "        [ 8.78792000e-01,  9.52894760e+01],\n",
       "        [ 3.81342000e-01, -6.17746400e+00],\n",
       "        [ 3.58474000e-01, -1.17317540e+01],\n",
       "        [ 2.70647000e-01,  1.37932010e+01],\n",
       "        [ 4.88904000e-01, -1.76418320e+01],\n",
       "        [ 1.06773000e-01,  5.68475700e+00],\n",
       "        [ 2.70112000e-01,  4.33567500e+00],\n",
       "        [ 7.54985000e-01,  7.58604330e+01],\n",
       "        [ 5.85174000e-01,  1.11640154e+02],\n",
       "        [ 4.58821000e-01,  1.20296920e+01],\n",
       "        [ 2.18017000e-01, -2.62348720e+01],\n",
       "        [ 5.83887000e-01,  9.94138500e+01],\n",
       "        [ 9.23626000e-01,  1.07802298e+02],\n",
       "        [ 8.33620000e-01,  1.04179678e+02],\n",
       "        [ 8.70691000e-01,  9.31325910e+01],\n",
       "        [ 2.49896000e-01, -8.61840400e+00],\n",
       "        [ 7.48230000e-01,  1.09160652e+02],\n",
       "        [ 1.93650000e-02,  3.40488840e+01],\n",
       "        [ 8.37588000e-01,  1.01239275e+02],\n",
       "        [ 5.29251000e-01,  1.15514729e+02],\n",
       "        [ 7.42898000e-01,  6.70387710e+01],\n",
       "        [ 5.22034000e-01,  6.41607990e+01],\n",
       "        [ 4.98982000e-01,  3.98306100e+00],\n",
       "        [ 4.79439000e-01,  2.43559080e+01],\n",
       "        [ 3.14834000e-01, -1.42562000e+01],\n",
       "        [ 7.53251000e-01,  8.50170920e+01],\n",
       "        [ 4.79362000e-01, -1.74804460e+01],\n",
       "        [ 9.50593000e-01,  9.90727840e+01],\n",
       "        [ 7.18623000e-01,  5.80802560e+01],\n",
       "        [ 2.18720000e-01, -1.96055930e+01],\n",
       "        [ 6.64113000e-01,  9.44371590e+01],\n",
       "        [ 9.42900000e-01,  1.31725134e+02],\n",
       "        [ 3.14226000e-01,  1.89048710e+01],\n",
       "        [ 2.84509000e-01,  1.17793460e+01],\n",
       "        [ 4.96200000e-03, -1.46241760e+01],\n",
       "        [ 2.24087000e-01, -5.05476490e+01],\n",
       "        [ 9.74331000e-01,  1.12822725e+02],\n",
       "        [ 8.94610000e-01,  1.12863995e+02],\n",
       "        [ 1.67350000e-01,  7.33800000e-02],\n",
       "        [ 7.53644000e-01,  1.05024456e+02],\n",
       "        [ 6.32241000e-01,  1.08625812e+02],\n",
       "        [ 3.14189000e-01, -6.09079700e+00],\n",
       "        [ 9.65527000e-01,  8.74183430e+01],\n",
       "        [ 8.20919000e-01,  9.46105380e+01],\n",
       "        [ 1.44107000e-01, -4.74838700e+00],\n",
       "        [ 7.25560000e-02, -5.68200800e+00],\n",
       "        [ 2.44700000e-03,  2.96857140e+01],\n",
       "        [ 8.51007000e-01,  7.96323760e+01],\n",
       "        [ 4.58024000e-01, -1.23260260e+01],\n",
       "        [ 6.27503000e-01,  1.39458881e+02],\n",
       "        [ 4.22259000e-01, -2.98274050e+01],\n",
       "        [ 7.14659000e-01,  6.34802710e+01],\n",
       "        [ 6.72320000e-01,  9.36085540e+01],\n",
       "        [ 4.98592000e-01,  3.71129750e+01],\n",
       "        [ 6.98906000e-01,  9.62828450e+01],\n",
       "        [ 8.61441000e-01,  9.96992300e+01],\n",
       "        [ 1.12425000e-01, -1.24199090e+01],\n",
       "        [ 1.64784000e-01,  5.24470400e+00],\n",
       "        [ 4.81531000e-01, -1.80704970e+01],\n",
       "        [ 3.75482000e-01,  1.77941100e+00],\n",
       "        [ 8.93250000e-02, -1.42167550e+01],\n",
       "        [ 3.66090000e-02, -6.26437200e+00],\n",
       "        [ 9.45004000e-01,  5.47235630e+01],\n",
       "        [ 1.36608000e-01,  1.49709360e+01],\n",
       "        [ 2.92285000e-01, -4.17237110e+01],\n",
       "        [ 2.91950000e-02, -6.60279000e-01],\n",
       "        [ 9.98307000e-01,  1.00124230e+02],\n",
       "        [ 3.03928000e-01, -5.49226400e+00],\n",
       "        [ 9.57863000e-01,  1.17824392e+02],\n",
       "        [ 8.15089000e-01,  1.13377704e+02],\n",
       "        [ 4.66399000e-01, -1.02498740e+01],\n",
       "        [ 8.76693000e-01,  1.15617275e+02],\n",
       "        [ 5.36121000e-01,  1.02997087e+02],\n",
       "        [ 3.73984000e-01, -3.73599360e+01],\n",
       "        [ 5.65162000e-01,  7.49674760e+01],\n",
       "        [ 8.54120000e-02, -2.14495630e+01],\n",
       "        [ 6.86411000e-01,  6.48596200e+01],\n",
       "        [ 9.08752000e-01,  1.07983366e+02],\n",
       "        [ 9.82829000e-01,  9.80054240e+01],\n",
       "        [ 5.27660000e-02, -4.21395020e+01],\n",
       "        [ 7.77552000e-01,  9.18993400e+01],\n",
       "        [ 3.74316000e-01, -3.52250100e+00],\n",
       "        [ 6.02310000e-02,  1.00082270e+01],\n",
       "        [ 5.26225000e-01,  8.73177220e+01],\n",
       "        [ 5.83872000e-01,  6.71044330e+01],\n",
       "        [ 2.38276000e-01,  1.06151590e+01],\n",
       "        [ 6.78747000e-01,  6.06242730e+01],\n",
       "        [ 6.76490000e-02,  1.59473980e+01],\n",
       "        [ 5.30182000e-01,  1.05030933e+02],\n",
       "        [ 8.69389000e-01,  1.04969996e+02],\n",
       "        [ 6.98410000e-01,  7.54604170e+01],\n",
       "        [ 5.49430000e-01,  8.25580680e+01]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myDatTest = loadDataSet('ex2test.txt')\n",
    "myDatTest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grow a tree on myMat1 with ops=(0,1).\n",
    "# NOTE(review): ops presumably holds (minimum error reduction, minimum samples per split) - confirm in createTree.\n",
    "# NOTE(review): the tree displayed after pruning splits on column 1 ('spInd': 1), while myDatTest appears to\n",
    "# have only columns 0 and 1 (feature, target) - check that myMat1 is the training set matching ex2test.txt.\n",
    "myTree = createTree(myMat1, ops=(0,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 4.237235,\n",
       "         'right': 3.7709250546875,\n",
       "         'spInd': 1,\n",
       "         'spVal': 0.998533},\n",
       "        'right': 4.040275150390625,\n",
       "        'spInd': 1,\n",
       "        'spVal': 0.952758},\n",
       "       'right': 4.40195,\n",
       "       'spInd': 1,\n",
       "       'spVal': 0.872288},\n",
       "      'right': 3.79251175,\n",
       "      'spInd': 1,\n",
       "      'spVal': 0.867298},\n",
       "     'right': 4.308686,\n",
       "     'spInd': 1,\n",
       "     'spVal': 0.832693},\n",
       "    'right': 3.768841,\n",
       "    'spInd': 1,\n",
       "    'spVal': 0.819006},\n",
       "   'right': 3.08343756640625,\n",
       "   'spInd': 1,\n",
       "   'spVal': 0.797583},\n",
       "  'right': 1.9286052814941406,\n",
       "  'spInd': 1,\n",
       "  'spVal': 0.582002},\n",
       " 'right': {'left': 1.245535791015625,\n",
       "  'right': {'left': 0.058443562500000004,\n",
       "   'right': {'left': -0.2779215,\n",
       "    'right': {'left': {'left': {'left': {'left': -0.054863743164062506,\n",
       "        'right': {'left': {'left': 0.025216,\n",
       "          'right': {'left': 0.068935,\n",
       "           'right': 0.072131,\n",
       "           'spInd': 1,\n",
       "           'spVal': 0.076386},\n",
       "          'spInd': 1,\n",
       "          'spVal': 0.081306},\n",
       "         'right': 0.006014,\n",
       "         'spInd': 1,\n",
       "         'spVal': 0.071476},\n",
       "        'spInd': 1,\n",
       "        'spVal': 0.084302},\n",
       "       'right': 0.22814800000000002,\n",
       "       'spInd': 1,\n",
       "       'spVal': 0.057534},\n",
       "      'right': -0.1930906484375,\n",
       "      'spInd': 1,\n",
       "      'spVal': 0.052031},\n",
       "     'right': 0.188975,\n",
       "     'spInd': 1,\n",
       "     'spVal': 0.004327},\n",
       "    'spInd': 1,\n",
       "    'spVal': 0.143143},\n",
       "   'spInd': 1,\n",
       "   'spVal': 0.148654},\n",
       "  'spInd': 1,\n",
       "  'spVal': 0.197834},\n",
       " 'spInd': 1,\n",
       " 'spVal': 0.39435}"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prune myTree against the test set; the tree returned by prune() is shown as the cell output.\n",
    "prune(myTree, myDatTest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linearSolve(dataSet):\n",
    "    '''Fit an ordinary least-squares linear model to dataSet via the normal equations.\n",
    "\n",
    "    The last column of dataSet is the target; every other column is a feature.\n",
    "    A column of ones is placed first so that ws[0] is the intercept.\n",
    "\n",
    "    Args:\n",
    "        dataSet: m x n matrix (or 2-D array) whose last column holds the targets.\n",
    "    Return:\n",
    "        ws: n x 1 matrix of coefficients, intercept first.\n",
    "        X: m x n design matrix (ones column followed by the features).\n",
    "        Y: m x 1 matrix of target values.\n",
    "    Raises:\n",
    "        NameError: if the determinant of X.T * X is exactly zero.\n",
    "    '''\n",
    "    # np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed. Coercing the\n",
    "    # input also keeps Y two-dimensional when a plain ndarray is passed in.\n",
    "    dataSet = np.asmatrix(dataSet)\n",
    "    m, n = np.shape(dataSet)\n",
    "    X = np.asmatrix(np.ones((m, n)))\n",
    "    X[:, 1:n] = dataSet[:, 0:n-1]  # column 0 keeps the ones (bias term)\n",
    "    Y = dataSet[:, -1]\n",
    "    xTx = X.T * X\n",
    "    if np.linalg.det(xTx) == 0.0:\n",
    "        raise NameError('This matrix is singular, cannot do inverse,\\n\\\n",
    "        try increasing the second value of ops')\n",
    "    # Solve the normal equations for the coefficients.\n",
    "    ws = xTx.I * (X.T * Y)\n",
    "    return ws, X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modelLeaf(dataSet):\n",
    "    '''Leaf generator for a model tree: the coefficients of a linear fit to dataSet.'''\n",
    "    coeffs = linearSolve(dataSet)[0]\n",
    "    return coeffs\n",
    "\n",
    "def modelErr(dataSet):\n",
    "    '''Error measure for a model tree: sum of squared residuals of a linear fit to dataSet.'''\n",
    "    coeffs, design, target = linearSolve(dataSet)\n",
    "    residuals = target - design * coeffs\n",
    "    return np.sum(np.power(residuals, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load exp2.txt, the data set the model tree in the next cell is built from.\n",
    "myMat2 = loadDataSet('exp2.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'left': matrix([[1.69855694e-03],\n",
       "         [1.19647739e+01]]), 'right': matrix([[3.46877936],\n",
       "         [1.18521743]]), 'spInd': 0, 'spVal': 0.285477}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a model tree: modelLeaf and modelErr are passed instead of relying on createTree's defaults,\n",
    "# so each leaf holds linear-regression coefficients (see the matrices in the output).\n",
    "createTree(myMat2, modelLeaf, modelErr, (1,10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helpers for testing which kind of model predicts better\n",
    "def regTreeEval(model, inDat):\n",
    "    '''Evaluate a regression-tree leaf: the leaf value cast to float (inDat is ignored).'''\n",
    "    leafValue = float(model)\n",
    "    return leafValue\n",
    "\n",
    "def modelTreeEval(model, inDat):\n",
    "    '''Evaluate a model-tree leaf on a single sample.\n",
    "\n",
    "    Args:\n",
    "        model: (n+1) x 1 coefficient matrix, intercept first.\n",
    "        inDat: 1 x n row of feature values.\n",
    "    Return:\n",
    "        The predicted target as a Python float.\n",
    "    '''\n",
    "    n = np.shape(inDat)[1]\n",
    "    # np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed.\n",
    "    X = np.asmatrix(np.ones((1, n + 1)))\n",
    "    X[:, 1:n + 1] = inDat  # column 0 stays 1 so it multiplies the intercept\n",
    "    # Take the single element explicitly: calling float() on a 1x1 matrix is\n",
    "    # deprecated since NumPy 1.25.\n",
    "    return float((X * model)[0, 0])\n",
    "\n",
    "def treeForeCast(tree, inData, modelEval=regTreeEval):\n",
    "    '''Predict the target for one sample by walking the tree down to a leaf.\n",
    "\n",
    "    Args:\n",
    "        tree: a tree dict with keys 'spInd', 'spVal', 'left', 'right', or a bare leaf.\n",
    "        inData: feature values of one sample (1 x n matrix row or 1-D sequence).\n",
    "        modelEval: callable (leaf, inData) used to evaluate the leaf that is reached.\n",
    "    Return:\n",
    "        Whatever modelEval returns for the leaf the sample falls into.\n",
    "    '''\n",
    "    if not isTree(tree):\n",
    "        return modelEval(tree, inData)\n",
    "    # Flatten before indexing: on a 1 x n matrix row, inData[k] selects row k, which\n",
    "    # raises IndexError for k >= 1 and yields the whole row for k == 0.\n",
    "    featVal = np.asarray(inData).ravel()[tree['spInd']]\n",
    "    # Values greater than the split value go left; everything else goes right.\n",
    "    branch = tree['left'] if featVal > tree['spVal'] else tree['right']\n",
    "    # The recursive call handles both subtrees and leaves (see the guard above).\n",
    "    return treeForeCast(branch, inData, modelEval)\n",
    "        \n",
    "def createForeCast(tree, testData, modelEval=regTreeEval):\n",
    "    '''Predict every row of testData with the given tree.\n",
    "\n",
    "    Args:\n",
    "        tree: tree dict (or bare leaf) accepted by treeForeCast.\n",
    "        testData: samples, one per row, containing only the feature columns.\n",
    "        modelEval: leaf evaluator forwarded to treeForeCast.\n",
    "    Return:\n",
    "        m x 1 matrix of predictions, one per row of testData.\n",
    "    '''\n",
    "    m = len(testData)\n",
    "    # np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed.\n",
    "    yHat = np.asmatrix(np.zeros((m, 1)))\n",
    "    for i in range(m):\n",
    "        yHat[i, 0] = treeForeCast(tree, testData[i], modelEval)\n",
    "    return yHat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9640852318222141"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regression tree on the bike-speed data: fit on the training file, predict the test file from its\n",
    "# first column, then score by the correlation between the predictions and the second column.\n",
    "trainMat = loadDataSet('bikeSpeedVsIq_train.txt')\n",
    "testMat = loadDataSet('bikeSpeedVsIq_test.txt')\n",
    "myTree = createTree(trainMat, ops=(1,20))\n",
    "yHat = createForeCast(myTree, testMat[:, 0])\n",
    "np.corrcoef(yHat, testMat[:, 1], rowvar=0)[0,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[37.58916794],\n",
       "        [ 6.18978355]])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline for comparison: a single linear model fitted to the whole training set.\n",
    "# ws[0] is the intercept and ws[1] the slope (linearSolve puts the ones column first).\n",
    "ws, X, Y = linearSolve(trainMat)\n",
    "ws"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9434684235674767"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict with the single linear model: intercept ws[0] plus slope ws[1] times the feature.\n",
    "# Computed in one vectorised step so the cell does not depend on a yHat left over from an earlier cell.\n",
    "yHat = testMat[:, 0] * ws[1, 0] + ws[0, 0]\n",
    "\n",
    "np.corrcoef(yHat, testMat[:, 1], rowvar=0)[0,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
